import numpy
import pymysql as mysql
import pandas as pd


def query(code, start_date, end_date):
    # 打开数据库连接
    conn = mysql.connect(host="106.52.210.202", port=3306, user="kaifamiao", passwd="123456", db="hs_stockdata")
    # 使用 cursor() 方法创建一个游标对象 cursor
    cur = conn.cursor()
    # 使用 execute()  方法执行 SQL 查询
    sql = "SELECT * FROM stockhistory WHERE 股票代码 = %s and 交易日期 >= %s and 交易日期 <= %s "
    cur.execute(sql, (code, start_date, end_date))
    # 使用 fetchall() 方法获取查询结果
    data = cur.fetchall()
    # 将查询结果转换为 DataFrame
    result = pd.DataFrame(data)
    # 将第 3 列交易日期数据从 datetime 转换为 str 类型
    if len(result) > 0:
        result[3] = result[3].astype(str)

    # 关闭数据库连接
    cur.close()
    conn.close()

    return result


# 查询所有的股票
def query_code():
    pass

if __name__ == "__main__":
    data = query("sh600000", "2023-05-04", "2023-05-08")

    # 日期列转换为 list
    # date = numpy.array(data[3]).tolist()
    # print(date)

    r = data[[7, 8, 9, 10]]
    y = numpy.array(r).tolist()
    print(y)